[Mlir-commits] [mlir] [mlir][sparse] support sparsifying batch levels (PR #83898)
llvmlistbot at llvm.org
Mon Mar 4 11:32:08 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir-sparse
@llvm/pr-subscribers-mlir
Author: Peiming Liu (PeimingLiu)
Patch is 35.56 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83898.diff
14 Files Affected:
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h (+4-1)
- (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h (+4-1)
- (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+3-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp (+4-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp (+32-12)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp (+12-11)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+1-1)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h (+4-2)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp (+90-43)
- (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h (+15-3)
- (added) mlir/test/Dialect/SparseTensor/sparse_batch.mlir (+48)
- (modified) mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir (+2-2)
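As the SparseAssembler hunk below shows, the external values buffer now carries one dense dimension per batch level ahead of a single dynamic dimension of stored values. A minimal plain-C++ sketch of that addressing, assuming (as the patch's FIXMEs note) that every batch stores the same number of entries; the type and names here are illustrative stand-ins, not part of MLIR:

```cpp
#include <cstddef>
#include <vector>

// Values buffer for a batched sparse tensor, shaped [batchSize x nse]:
// one dense dimension per batch level in front, a single dynamic
// dimension of stored values behind it. Assumes every batch stores the
// same number of entries (nse).
struct BatchedValues {
  std::vector<double> data;
  size_t nse; // stored entries per batch
  double &at(size_t batch, size_t entry) { return data[batch * nse + entry]; }
};

int main() {
  BatchedValues vals{std::vector<double>(2 * 3, 0.0), /*nse=*/3};
  vals.at(1, 2) = 4.2; // last stored entry of batch 1
  return 0;
}
```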
``````````diff
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
index ce34ae43d1c181..7aa9cb6119434b 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorStorageLayout.h
@@ -42,7 +42,10 @@ namespace sparse_tensor {
///
/// struct sparse_tensor.storage_specifier {
/// array<rank x int> lvlSizes ; sizes/cardinalities for each level
-/// array<n x int> memSizes; ; sizes/lengths for each data memref
+/// // TODO: memSizes need to be expanded to array<[batch] x n x int> to
+/// // support different sizes for different batches. At the moment, we
+/// // assume that every batch occupies the same memory size.
+/// array<n x int> memSizes ; sizes/lengths for each data memref
/// }
/// };
///
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
index c93a4fcd922c28..930922e053d441 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -253,7 +253,10 @@ class SparseTensorType {
CrdTransDirectionKind::dim2lvl);
}
- /// Returns the Level-shape.
+ /// Returns the batched level rank.
+ unsigned getBatchLvlRank() const { return getEncoding().getBatchLvlRank(); }
+
+ /// Returns the batched Level-shape.
SmallVector<Size> getBatchLvlShape() const {
auto lvlShape = getEncoding().tranlateShape(getDimShape(),
CrdTransDirectionKind::dim2lvl);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 232635ca84a47e..ca91f18cfa4348 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -374,7 +374,7 @@ Level SparseTensorEncodingAttr::getLvlRank() const {
LevelType SparseTensorEncodingAttr::getLvlType(Level l) const {
if (!getImpl())
- return LevelFormat::Dense;
+ return LevelFormat::Batch;
assert(l < getLvlRank() && "Level is out of bounds");
return getLvlTypes()[l];
}
@@ -1755,6 +1755,8 @@ LogicalResult ConcatenateOp::verify() {
LogicalResult InsertOp::verify() {
const auto stt = getSparseTensorType(getTensor());
+ if (stt.getEncoding().getBatchLvlRank() > 0)
+ return emitOpError("batched sparse tensor insertion not implemented");
if (stt.getLvlRank() != static_cast<Level>(getLvlCoords().size()))
return emitOpError("incorrect number of coordinates");
return success();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
index cd6b9b49893731..b39a2d9c57d8b0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseAssembler.cpp
@@ -33,7 +33,8 @@ static void convTypes(TypeRange types, SmallVectorImpl<Type> &convTypes,
}
// Convert the external representation of the values array.
const SparseTensorType stt(cast<RankedTensorType>(type));
- auto shape = {ShapedType::kDynamic};
+ auto shape = stt.getBatchLvlShape();
+ shape.push_back(ShapedType::kDynamic);
auto vtp = RankedTensorType::get(shape, stt.getElementType());
convTypes.push_back(vtp);
if (extraTypes)
@@ -72,7 +73,8 @@ static void convVals(OpBuilder &builder, Location loc, TypeRange types,
// Convert the external representation of the values array.
auto rtp = cast<RankedTensorType>(type);
const SparseTensorType stt(rtp);
- auto shape = {ShapedType::kDynamic};
+ auto shape = stt.getBatchLvlShape();
+ shape.push_back(ShapedType::kDynamic);
SmallVector<Value> inputs;
SmallVector<Type> retTypes;
SmallVector<Type> cntTypes;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
index 4e3393195813c3..5da8a60d2d5fb0 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
@@ -429,11 +429,18 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem,
}
/// Creates the reassociation array.
-static ReassociationIndices getReassociationForFlattening(ShapedType srcTp) {
- ReassociationIndices reassociation;
- for (int i = 0, e = srcTp.getRank(); i < e; i++)
- reassociation.push_back(i);
- return reassociation;
+static SmallVector<ReassociationIndices>
+getReassociationForFlattening(ShapedType srcTp, unsigned batchLvls) {
+ SmallVector<ReassociationIndices> ret(batchLvls + 1, {});
+ // Create the reassociation in the form:
+ //   {0}, {1}, ..., {batchLvls - 1}, {batchLvls, ..., rank - 1}
+ for (unsigned i = 0; i < batchLvls; i++)
+ ret[i].push_back(i);
+
+ for (int i = batchLvls, e = srcTp.getRank(); i < e; i++)
+ ret.back().push_back(i);
+
+ return ret;
}
//===----------------------------------------------------------------------===//
@@ -1287,9 +1294,10 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
: op.getLevels()[fIdx];
// TODO: handle batch.
TypedValue<BaseMemRefType> mem = genToMemref(rewriter, loc, tensor);
- if (mem.getType().getRank() > 1) {
- // Flattens the buffer to rank 1.
- auto reassoc = getReassociationForFlattening(mem.getType());
+ if (mem.getType().getRank() > stt.getBatchLvlRank() + 1) {
+ // Flattens the buffer to rank batchLvlRank + 1.
+ auto reassoc = getReassociationForFlattening(
+ mem.getType(), stt.getBatchLvlRank());
mem = rewriter.create<memref::CastOp>(
loc, fType,
rewriter.create<memref::CollapseShapeOp>(loc, mem, reassoc));
@@ -1325,11 +1333,17 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
// Sets up the memory size by reading the last value in position array.
LevelType lt = stt.getLvlType(lvl);
// Simply forwards the position index when this is a dense level.
- if (isDenseLT(lt)) {
+ if (lt.isa<LevelFormat::Dense>()) {
memSize = rewriter.create<arith::MulIOp>(loc, lvlSize, memSize);
posBack = rewriter.create<arith::SubIOp>(loc, memSize, c1);
continue;
}
+ if (lt.isa<LevelFormat::Batch>()) {
+ // Skip batch levels, since they are not linearized.
+ // FIXME: this assumes that every batch has the same nse; needs to be
+ // generalized to handle varied-size batches.
+ continue;
+ }
if (isWithPosLT(lt)) {
assert(isCompressedLT(lt) || isLooseCompressedLT(lt));
@@ -1343,7 +1357,12 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
}
desc.setPosMemSize(rewriter, loc, lvl, memSize);
// The last value in position array is the memory size for next level.
- memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), posBack);
+ // FIXME: this assumes that every batch has the same nse; needs to be
+ // generalized to handle varied-size batches.
+ SmallVector<Value> batched(stt.getBatchLvlRank(),
+ constantIndex(rewriter, loc, 0));
+ batched.push_back(posBack);
+ memSize = genIndexLoad(rewriter, loc, desc.getPosMemRef(lvl), batched);
posBack = rewriter.create<arith::SubIOp>(loc, posBack, c1);
}
assert(isWithCrdLT(lt) && lvl <= trailCOOStart);
@@ -1413,8 +1432,9 @@ struct SparseDisassembleOpConverter
retLen.push_back(genScalarToTensor(rewriter, loc, sz, lvlLenTp));
}
Value flatOut = dst;
- if (dst.getType().getRank() != 1) {
- auto reassoc = getReassociationForFlattening(dst.getType());
+ if (dst.getType().getRank() > stt.getBatchLvlRank() + 1) {
+ auto reassoc =
+ getReassociationForFlattening(dst.getType(), stt.getBatchLvlRank());
flatOut = rewriter.create<memref::CollapseShapeOp>(loc, dst, reassoc);
}
Value dstMem = genSliceToSize(rewriter, loc, flatOut, sz);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 6ff21468e05764..5150615af180c8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -1221,7 +1221,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
}
Value vals = loopEmitter.getValBuffer()[0];
- Value pos = loopEmitter.getValPosits(0);
+ SmallVector<Value> pos = loopEmitter.getValPosits(0);
// Loads the value from sparse tensor using position-index;
// loads the value from dense tensor using coords.
Value val = enc ? rewriter.create<memref::LoadOp>(loc, vals, pos)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 8f2ae60b311f7c..1fb70ed5035c03 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -86,7 +86,7 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
case AffineExprKind::Add:
case AffineExprKind::Mul:
case AffineExprKind::Constant: {
- assert(isDenseLT(lt));
+ assert(lt.hasDenseSemantic());
if (auto binOp = dyn_cast<AffineBinaryOpExpr>(a)) {
// We do not set dim level format for affine expression like d0 + d1 on
// either loop index at d0 or d1. We continue the recursion merely to
@@ -211,7 +211,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
"AffineMap does not have dimension-rank many results");
unsigned num = 0;
for (Level l = 0; l < lvlRank; l++) {
- if (!isa<AffineDimExpr>(exprs[l]) && !stt.isDenseLvl(l))
+ if (!isa<AffineDimExpr>(exprs[l]) && !stt.getLvlType(l).hasDenseSemantic())
num++;
}
return num;
@@ -355,8 +355,8 @@ static Value genSubscript(CodegenEnv &env, OpBuilder &builder, OpOperand *t,
if (stt.hasEncoding()) {
// For sparse tensors we only push the last-level's position onto `args`.
const auto pos = env.emitter().getValPosits(tid);
- assert(pos);
- args.push_back(pos);
+ assert(!pos.empty());
+ args.append(pos);
} else {
// For dense tensors we push all level's coordinates onto `args`.
const Level lvlRank = stt.getLvlRank();
@@ -801,7 +801,7 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
// `CodegenEnv::lt(TensorId, LoopId)`. The returned LT from CodegenEnv
// should be consistent with the LT indexed by <TensorId, Level>.
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, curr);
- return isCompressedLT(lt) || isSingletonLT(lt);
+ return lt.hasSparseSemantic();
});
return isParallelFor(env, /*isOuter=*/curr == 0, isSparse);
}
@@ -890,15 +890,14 @@ static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
}
assert(curr == env.merger().loop(b));
Value clause;
- if (isCompressedLT(lt) || isSingletonLT(lt) ||
- isLooseCompressedLT(lt) || isNOutOfMLT(lt)) {
+ if (lt.hasSparseSemantic()) {
assert(lvl.has_value());
const Value crd = env.emitter().getCoord(tid, *lvl);
const Value lvar = env.getLoopVar(curr);
clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
crd, lvar);
} else {
- assert(isDenseLT(lt) || isUndefLT(lt));
+ assert(lt.hasDenseSemantic() || isUndefLT(lt));
clause = constantI1(builder, loc, true);
}
cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause;
@@ -988,7 +987,7 @@ static bool getAllTidLvlsInLatPoints(
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
- } else if (isDenseLT(lt) || isIdxReduc) {
+ } else if (lt.hasDenseSemantic() || isIdxReduc) {
callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
@@ -1010,7 +1009,8 @@ static bool getAllTidLvlsInLatPoints(
AffineExpr exp = affines[l];
// Skip simple affine expression and non-dense levels (which
// have their own filter loop).
- if (isa<AffineDimExpr>(exp) || !stt.isDenseLvl(l))
+ LevelType lt = stt.getLvlType(l);
+ if (isa<AffineDimExpr>(exp) || !lt.hasDenseSemantic())
continue;
// Constant affine expression are handled in genLoop.
@@ -1103,7 +1103,8 @@ static void genConstantDenseAddressFromLevel(CodegenEnv &env,
assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
for (Level l = startLvl; l < lvlRank; l++) {
AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ if (enc.getLvlType(l).hasDenseSemantic() &&
+ isa<AffineConstantExpr>(lvlExpr))
env.emitter().locateLvlAtAffineAddress(
builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
else
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
index fa570159ba41ca..89af75dea2a0f2 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.cpp
@@ -175,7 +175,7 @@ Value sparse_tensor::genScalarToTensor(OpBuilder &builder, Location loc,
}
Value sparse_tensor::genIndexLoad(OpBuilder &builder, Location loc, Value mem,
- Value s) {
+ ValueRange s) {
Value load = builder.create<memref::LoadOp>(loc, mem, s);
if (!isa<IndexType>(load.getType())) {
if (load.getType().getIntOrFloatBitWidth() < 64)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index e8f6bd1c5eaeb1..ce5831d999e9a4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -149,7 +149,7 @@ Value genScalarToTensor(OpBuilder &builder, Location loc, Value elem,
/// Generates a pointer/index load from the sparse storage scheme. Narrower
/// data types need to be zero extended before casting the value into the
/// index type used for looping and indexing.
-Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, Value s);
+Value genIndexLoad(OpBuilder &builder, Location loc, Value mem, ValueRange s);
/// Generates a 1-valued attribute of the given type. This supports
/// all the same types as `getZeroAttr`; however, unlike `getZeroAttr`,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
index 7bfe713cdd9f74..b5a0ac8484abdd 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h
@@ -220,9 +220,11 @@ class LoopEmitter {
///
/// Getters.
///
- Value getValPosits(TensorId tid) const {
+ SmallVector<Value> getValPosits(TensorId tid) const {
+ SmallVector<Value> batchCrds = iters[tid].back().back()->getBatchCrds();
Value lastLvlPos = iters[tid].back().back()->getCurPosition().first;
- return lastLvlPos;
+ batchCrds.push_back(lastLvlPos);
+ return batchCrds;
};
Value getCoord(TensorId tid, Level lvl) const {
return getCurIterator(tid, lvl).getCrd();
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
index 8edacaa9981ef8..a456c87445eafc 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp
@@ -52,8 +52,11 @@ class SparseLevel : public SparseTensorLevel {
Value crdBuffer)
: SparseTensorLevel(tid, lvl, lt, lvlSize), crdBuffer(crdBuffer) {}
- Value peekCrdAt(OpBuilder &b, Location l, Value iv) const override {
- return genIndexLoad(b, l, crdBuffer, iv);
+ Value peekCrdAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value iv) const override {
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(iv);
+ return genIndexLoad(b, l, crdBuffer, memCrd);
}
protected:
@@ -62,26 +65,35 @@ class SparseLevel : public SparseTensorLevel {
class DenseLevel : public SparseTensorLevel {
public:
- DenseLevel(unsigned tid, Level lvl, Value lvlSize, bool encoded)
- : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize),
- encoded(encoded) {}
+ DenseLevel(unsigned tid, Level lvl, Value lvlSize)
+ : SparseTensorLevel(tid, lvl, LevelFormat::Dense, lvlSize) {}
- Value peekCrdAt(OpBuilder &, Location, Value pos) const override {
- return pos;
+ Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+ llvm_unreachable("locate dense level instead");
}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
+ Value max) const override {
+ Value posLo = MULI(p, lvlSize);
+ return {posLo, lvlSize};
+ }
+};
+
+class BatchLevel : public SparseTensorLevel {
+public:
+ BatchLevel(unsigned tid, Level lvl, Value lvlSize)
+ : SparseTensorLevel(tid, lvl, LevelFormat::Batch, lvlSize) {}
+
+ Value peekCrdAt(OpBuilder &, Location, ValueRange, Value) const override {
+ llvm_unreachable("locate dense level instead");
+ }
+
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange, Value p,
Value max) const override {
assert(max == nullptr && "Dense level can not be non-unique.");
- if (encoded) {
- Value posLo = MULI(p, lvlSize);
- return {posLo, lvlSize};
- }
// No need to linearize the position for non-annotated tensors.
return {C_IDX(0), lvlSize};
}
-
- const bool encoded;
};
class CompressedLevel : public SparseLevel {
@@ -90,14 +102,17 @@ class CompressedLevel : public SparseLevel {
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- if (max == nullptr) {
- Value pLo = genIndexLoad(b, l, posBuffer, p);
- Value pHi = genIndexLoad(b, l, posBuffer, ADDI(p, C_IDX(1)));
- return {pLo, pHi};
- }
- llvm_unreachable("compressed-nu should be the first non-unique level.");
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value max) const override {
+ assert(max == nullptr &&
+ "compressed level must be the first non-unique level.");
+
+ SmallVector<Value> memCrd(batchPrefix);
+ memCrd.push_back(p);
+ Value pLo = genIndexLoad(b, l, posBuffer, memCrd);
+ memCrd.back() = ADDI(p, C_IDX(1));
+ Value pHi = genIndexLoad(b, l, posBuffer, memCrd);
+ return {pLo, pHi};
}
private:
@@ -110,12 +125,17 @@ class LooseCompressedLevel : public SparseLevel {
Value posBuffer, Value crdBuffer)
: SparseLevel(tid, lvl, lt, lvlSize, crdBuffer), posBuffer(posBuffer) {}
- ValuePair peekRangeAt(OpBuilder &b, Location l, Value p,
- Value max) const override {
- assert(max == nullptr && "loss compressed level can not be non-unique.");
+ ValuePair peekRangeAt(OpBuilder &b, Location l, ValueRange batchPrefix,
+ Value p, Value max) const override {
+ assert(max == nullptr &&
+ ...
[truncated]
``````````
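The getReassociationForFlattening change above keeps each leading batch level as its own group and collapses only the trailing levels into one. A standalone plain-C++ sketch of that grouping (the function and names are illustrative stand-ins, not the MLIR helpers):

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

using ReassociationIndices = std::vector<int>;

static std::vector<ReassociationIndices>
reassociationForFlattening(int rank, int batchLvls) {
  assert(batchLvls < rank && "need at least one non-batch dimension");
  std::vector<ReassociationIndices> ret(batchLvls + 1);
  // Batch dimensions stay as-is: {0}, {1}, ..., {batchLvls - 1}.
  for (int i = 0; i < batchLvls; i++)
    ret[i].push_back(i);
  // All trailing dimensions fold into one group: {batchLvls, ..., rank - 1}.
  for (int i = batchLvls; i < rank; i++)
    ret.back().push_back(i);
  return ret;
}

int main() {
  // A rank-4 buffer with 2 batch levels groups as {0}, {1}, {2, 3},
  // i.e. it collapses to a rank-3 memref.
  for (const auto &group : reassociationForFlattening(4, 2)) {
    for (int idx : group)
      printf("%d ", idx);
    printf("| ");
  }
  printf("\n"); // prints: 0 | 1 | 2 3 |
}
```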
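Similarly, CompressedLevel::peekRangeAt now addresses the positions buffer with the batch coordinates as a prefix ahead of the position index. A minimal sketch of the resulting lookup, again with stand-in types and under the patch's same-nse-per-batch assumption:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Flat storage for a positions array of shape [numBatches x posLen];
// every batch is assumed to occupy the same memory size, matching the
// FIXMEs in the patch.
struct PosBuffer {
  std::vector<int64_t> data;
  int64_t posLen; // position entries per batch
  int64_t load(int64_t batch, int64_t p) const {
    return data[batch * posLen + p];
  }
};

// Mirrors the shape of CompressedLevel::peekRangeAt: with batch prefix `b`,
// position `p` owns the nonzero range [pos[b][p], pos[b][p + 1]).
static std::pair<int64_t, int64_t> peekRangeAt(const PosBuffer &pos,
                                               int64_t batch, int64_t p) {
  return {pos.load(batch, p), pos.load(batch, p + 1)};
}

int main() {
  // Two batches of a compressed level with 3 rows each; row 2 of batch 1
  // owns the entries in [3, 5).
  PosBuffer pos{{0, 2, 2, 4,   // batch 0
                 0, 1, 3, 5},  // batch 1
                /*posLen=*/4};
  auto [lo, hi] = peekRangeAt(pos, /*batch=*/1, /*p=*/2);
  return (lo == 3 && hi == 5) ? 0 : 1;
}
```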
https://github.com/llvm/llvm-project/pull/83898